"""
Phase 1: Data Preparation — Crisis & Sudden Stop Variables
==========================================================
Download Laeven & Valencia (2018) systemic crisis database,
construct crisis indicators, sudden stop / CA reversal variables,
and early warning variables. Merge with 69-country demographic panel.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import zipfile
import io
import re
import warnings
warnings.filterwarnings('ignore')

try:
    import requests
except ImportError:
    requests = None

try:
    import openpyxl
except ImportError:
    openpyxl = None

# Directory layout, resolved from this file's own location:
#   PROJECT_DIR      – two levels above this file (parent of its directory)
#   ROOT_DIR         – parent of PROJECT_DIR
#   MULTILATERAL_DIR – "multilateral" directory under ROOT_DIR; main() reads
#                      the input panel from beneath it
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"

# Output locations. NOTE: both directories are created at import time, so
# merely importing this module touches the filesystem.
OUT_DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_DATA.mkdir(parents=True, exist_ok=True)
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# ── Laeven & Valencia crisis database ────────────────────────────────

# Remote location of the zipped L&V dataset, fetched by download_lv_database().
LV_URL = "https://www.imf.org/-/media/files/publications/wp/2018/datasets/wp18206.zip"
# Local copy of the archive; when it exists it is reused instead of downloading
# again (unless download_lv_database(force=True) is requested).
LV_CACHE = OUT_DATA / "wp18206.zip"

# Manual country name → ISO3 mapping for L&V edge cases.
# country_to_iso3() consults this table BEFORE falling back to pycountry, so an
# entry here always wins. Keys are compared exactly (after stripping
# surrounding whitespace) and are therefore case-sensitive; several spelling
# variants of the same country are listed on purpose.
COUNTRY_FIXES = {
    "Korea": "KOR",
    "Korea, Rep.": "KOR",
    "Korea, Republic of": "KOR",
    "Republic of Korea": "KOR",
    "Russia": "RUS",
    "Russian Federation": "RUS",
    "Venezuela": "VEN",
    "Venezuela, RB": "VEN",
    "Iran": "IRN",
    "Iran, Islamic Rep.": "IRN",
    "Egypt": "EGY",
    "Egypt, Arab Rep.": "EGY",
    "Czech Republic": "CZE",
    "Czechia": "CZE",
    "Slovak Republic": "SVK",
    "Slovakia": "SVK",
    "Ivory Coast": "CIV",
    "Côte d'Ivoire": "CIV",
    "Cote d'Ivoire": "CIV",
    "Congo, Dem. Rep.": "COD",
    "Congo, Democratic Republic of": "COD",
    "Democratic Republic of the Congo": "COD",
    "Congo, Rep.": "COG",
    "Congo, Republic of": "COG",
    "Republic of Congo": "COG",
    "Lao PDR": "LAO",
    "Lao P.D.R.": "LAO",
    "Laos": "LAO",
    "Vietnam": "VNM",
    "Viet Nam": "VNM",
    "Bolivia": "BOL",
    "Tanzania": "TZA",
    "Myanmar": "MMR",
    "Burma": "MMR",
    "Kyrgyz Republic": "KGZ",
    "Kyrgyzstan": "KGZ",
    "Macedonia": "MKD",
    "North Macedonia": "MKD",
    "FYR Macedonia": "MKD",
    "Swaziland": "SWZ",
    "Eswatini": "SWZ",
    "Cape Verde": "CPV",
    "Cabo Verde": "CPV",
    "Turkiye": "TUR",
    "Turkey": "TUR",
    "United States": "USA",
    "United Kingdom": "GBR",
    "UK": "GBR",
    "US": "USA",
    "China": "CHN",
    "China, P.R.": "CHN",
    "China, People's Republic of": "CHN",
}


def country_to_iso3(name):
    """Map a country name to its ISO3 code, or return None if unresolved.

    Resolution order:
      1. exact match in the manual ``COUNTRY_FIXES`` table,
      2. exact ``name`` match in pycountry,
      3. first result of pycountry's fuzzy search.

    Args:
        name: Country name as it appears in the source data. Anything that is
            not a ``str`` (e.g. NaN from an empty spreadsheet cell) yields None.

    Returns:
        The three-letter code as ``str``, or ``None`` when the name is not a
        string, is blank, cannot be resolved, or pycountry is not installed.
    """
    if not isinstance(name, str):
        return None
    name = name.strip()
    if not name:
        # A blank name carries no information. Letting it reach the fuzzy
        # search may match an arbitrary country, so refuse instead of guessing.
        return None
    if name in COUNTRY_FIXES:
        return COUNTRY_FIXES[name]

    try:
        import pycountry
    except ImportError:
        # pycountry is optional: without it only the manual table is available.
        return None

    try:
        exact = pycountry.countries.get(name=name)
        if exact:
            return exact.alpha_3
        candidates = pycountry.countries.search_fuzzy(name)
        if candidates:
            return candidates[0].alpha_3
    except Exception:
        # Deliberate best-effort: a failed lookup (the fuzzy search raises
        # LookupError when nothing matches) or any other library hiccup on one
        # odd name must not abort parsing of the whole dataset.
        return None
    return None


def download_lv_database(force=False):
    """Download and cache the Laeven & Valencia crisis database archive.

    Args:
        force: When True, ignore an existing cached archive and download again.

    Returns:
        ``LV_CACHE`` (a ``Path``) when a cached or freshly downloaded archive
        is available, otherwise ``None`` so the caller can switch to the
        hardcoded fallback data. ``None`` is returned when ``requests`` is not
        installed, the request fails or times out, the server answers with a
        non-200 status, or the payload is not a zip archive.
    """
    if LV_CACHE.exists() and not force:
        print(f"  Using cached: {LV_CACHE}")
        return LV_CACHE

    if requests is None:
        print("  WARNING: requests not available, using fallback crisis data")
        return None

    print("  Downloading L&V database from IMF...")
    try:
        resp = requests.get(LV_URL, timeout=60)
    except requests.RequestException as exc:
        # Connection errors and timeouts must degrade to the fallback data
        # rather than crash the pipeline.
        print(f"  WARNING: Download failed ({exc}), using fallback")
        return None
    if resp.status_code != 200:
        print(f"  WARNING: Download failed (status {resp.status_code}), using fallback")
        return None

    payload = resp.content
    # Validate before caching: an HTML error page served with status 200 would
    # otherwise be stored and reused on every later run.
    if not zipfile.is_zipfile(io.BytesIO(payload)):
        print("  WARNING: Downloaded file is not a valid zip archive, using fallback")
        return None

    LV_CACHE.write_bytes(payload)
    print(f"  Saved: {LV_CACHE} ({len(payload)/1024:.0f} KB)")
    return LV_CACHE


def parse_year_string(s):
    """Return every calendar year found in an L&V spreadsheet cell.

    The cell may hold a single year, several years in one string
    (``'1988, 1994'``), free text such as ``'ongoing'``, or be empty.
    Only four-digit years 1900–2029 standing as whole words are recognised.

    Returns:
        A list of ``int`` years in order of appearance; empty when the cell is
        missing, a placeholder, or contains no recognisable year.
    """
    if pd.isna(s):
        return []

    text = str(s).strip()
    placeholders = ('', 'nan', 'none', '…', '-')
    if text.lower() in placeholders:
        return []

    found = re.finditer(r'\b(19\d{2}|20[012]\d)\b', text)
    return [int(match.group(1)) for match in found]


def _pick_workbook(member_names):
    """Choose the workbook inside the archive, preferring .xlsx over legacy .xls."""
    workbooks = [m for m in member_names if m.lower().endswith(('.xlsx', '.xls'))]
    if not workbooks:
        return None
    modern = [m for m in workbooks if m.lower().endswith('.xlsx')]
    return modern[0] if modern else workbooks[0]


def _read_banking_durations(xls):
    """Map (iso3, start_year) → end_year from the first resolution/outcome sheet."""
    durations = {}
    for sheet in xls.sheet_names:
        lowered = sheet.lower()
        if 'resolution' not in lowered and 'outcome' not in lowered:
            continue

        df_res = pd.read_excel(xls, sheet_name=sheet)
        cols = df_res.columns.tolist()
        # Columns are taken by position: Country | Start | End | ...
        country_col, start_col, end_col = cols[0], cols[1], cols[2]

        for _, row in df_res.iterrows():
            country = row[country_col]
            if pd.isna(country):
                continue
            iso3 = country_to_iso3(str(country))
            starts = parse_year_string(row[start_col])
            ends = parse_year_string(row[end_col])
            if iso3 and starts:
                start_yr = starts[0]
                # No parseable end year (e.g. free text): treat the episode as
                # lasting only its start year.
                durations[(iso3, start_yr)] = ends[0] if ends else start_yr
        print(f"  Resolution sheet: {len(durations)} banking episodes with duration")
        break  # only the first matching sheet is used
    return durations


def _read_crisis_years(xls, banking_durations):
    """Collect banking/currency/sovereign episode records from the 'Crisis Years' sheet."""
    episodes = []
    for sheet in xls.sheet_names:
        if 'crisis year' not in sheet.lower():
            continue

        df_cy = pd.read_excel(xls, sheet_name=sheet)
        cols = df_cy.columns.tolist()
        print(f"  Crisis Years columns: {cols}")

        # Columns by position: Country, Banking start, Currency, Sovereign, ...
        country_col = cols[0]
        typed_cols = [
            (crisis_type, cols[pos])
            for pos, crisis_type in enumerate(('banking', 'currency', 'sovereign'), start=1)
            if pos < len(cols)
        ]

        for _, row in df_cy.iterrows():
            country = row[country_col]
            if pd.isna(country) or str(country).strip() == '':
                continue
            iso3 = country_to_iso3(str(country))
            if not iso3:
                continue
            country_name = str(country).strip()

            for crisis_type, col in typed_cols:
                for yr in parse_year_string(row[col]):
                    # Only banking crises have a duration on record; currency
                    # and sovereign crises are stored as single-year events.
                    if crisis_type == 'banking':
                        end_yr = banking_durations.get((iso3, yr), yr)
                    else:
                        end_yr = yr
                    episodes.append({
                        'iso3': iso3,
                        'crisis_type': crisis_type,
                        'start_year': yr,
                        'end_year': end_yr,
                        'country_name': country_name,
                    })
        break  # only the first matching sheet is used
    return episodes


def parse_lv_excel(zip_path):
    """Parse the L&V workbook inside a zip archive into crisis episodes.

    L&V 2018 format (columns are read by position, not by header name):
    - 'Crisis Years' sheet: Country | Banking Start | Currency | Sovereign | Restructuring
    - 'Crisis Resolution and Outcomes': Country | Start | End | Output loss | ...

    Args:
        zip_path: Path (or file-like object) of the zip archive.

    Returns:
        DataFrame with columns ``iso3``, ``crisis_type``, ``start_year``,
        ``end_year``, ``country_name``, de-duplicated on
        (iso3, crisis_type, start_year). An empty DataFrame is returned when
        the archive holds no workbook or no episode could be parsed.
    """
    with zipfile.ZipFile(zip_path) as zf:
        excel_name = _pick_workbook(zf.namelist())
        if excel_name is None:
            print(f"  WARNING: No Excel files in zip. Contents: {zf.namelist()}")
            return pd.DataFrame()
        print(f"  Parsing: {excel_name}")
        workbook_bytes = zf.read(excel_name)

    # openpyxl only understands .xlsx; for a legacy .xls let pandas pick the engine.
    engine = 'openpyxl' if excel_name.lower().endswith('.xlsx') else None
    with pd.ExcelFile(io.BytesIO(workbook_bytes), engine=engine) as xls:
        print(f"  Sheets: {xls.sheet_names}")
        banking_durations = _read_banking_durations(xls)
        episodes = _read_crisis_years(xls, banking_durations)

    ep_df = pd.DataFrame(episodes)
    if ep_df.empty:
        print("  WARNING: No episodes parsed from L&V Excel")
        return ep_df

    # Deduplicate (first occurrence wins)
    ep_df = ep_df.drop_duplicates(subset=['iso3', 'crisis_type', 'start_year'])

    print(f"  Total episodes parsed: {len(ep_df)}")
    for ct in ['banking', 'currency', 'sovereign']:
        of_type = ep_df['crisis_type'] == ct
        n = int(of_type.sum())
        nc = ep_df.loc[of_type, 'iso3'].nunique()
        print(f"    {ct}: {n} episodes in {nc} countries")

    return ep_df


def build_fallback_crisis_data():
    """Build the hardcoded episode table used when the L&V file is unavailable.

    Returns:
        DataFrame with one row per episode and columns ``iso3``,
        ``crisis_type`` ('banking' / 'currency' / 'sovereign'),
        ``start_year``, ``end_year`` and ``country_name`` (set to the ISO3
        code, since no display name is stored here).
    """
    # Each entry is (iso3, start_year, end_year).
    catalogue = {
        # Major systemic banking crises — key episodes
        'banking': [
            # 2007-2011 wave
            ("USA", 2007, 2011), ("GBR", 2007, 2011), ("DEU", 2008, 2010),
            ("FRA", 2008, 2009), ("ITA", 2008, 2011), ("ESP", 2008, 2012),
            ("GRC", 2008, 2012), ("PRT", 2008, 2012), ("IRL", 2008, 2012),
            ("ISL", 2008, 2012), ("NLD", 2008, 2009), ("AUT", 2008, 2012),
            ("BEL", 2008, 2012), ("DNK", 2008, 2009), ("SWE", 2008, 2009),
            ("CHE", 2008, 2009), ("RUS", 2008, 2009), ("UKR", 2008, 2010),
            ("KAZ", 2008, 2010), ("HUN", 2008, 2012), ("LVA", 2008, 2012),
            ("NGA", 2009, 2012),
            # 1990s
            ("JPN", 1997, 2001), ("KOR", 1997, 1998), ("THA", 1997, 2000),
            ("IDN", 1997, 2001), ("MYS", 1997, 1999), ("PHL", 1997, 2001),
            ("MEX", 1994, 1996), ("ARG", 1995, 1995), ("ARG", 2001, 2003),
            ("BRA", 1990, 1994), ("TUR", 2000, 2001), ("RUS", 1998, 1998),
            ("COL", 1998, 2000), ("ECU", 1998, 2002),
            # 1980s-90s Nordic + others
            ("FIN", 1991, 1995), ("SWE", 1991, 1995), ("NOR", 1991, 1993),
            ("CHL", 1981, 1985), ("VEN", 1994, 1998), ("IND", 1993, 1993),
            ("CHN", 1998, 1998), ("CZE", 1996, 2000),
        ],
        'currency': [
            ("MEX", 1994, 1994), ("THA", 1997, 1997), ("KOR", 1997, 1997),
            ("IDN", 1997, 1997), ("MYS", 1997, 1997), ("PHL", 1997, 1997),
            ("RUS", 1998, 1998), ("BRA", 1999, 1999), ("ARG", 2002, 2002),
            ("TUR", 2001, 2001), ("ISL", 2008, 2008), ("UKR", 2008, 2008),
            ("GBR", 1992, 1992), ("ITA", 1992, 1992), ("FIN", 1993, 1993),
            ("SWE", 1993, 1993), ("NOR", 1992, 1992), ("VEN", 1994, 1994),
            ("ECU", 1999, 1999), ("ZAF", 2001, 2001), ("EGY", 2003, 2003),
            ("NGA", 2016, 2016), ("GHA", 2014, 2014),
        ],
        'sovereign': [
            ("ARG", 2001, 2005), ("RUS", 1998, 2000), ("ECU", 1999, 2000),
            ("GRC", 2012, 2012), ("UKR", 2008, 2009),
            ("MEX", 1982, 1990), ("BRA", 1983, 1994), ("PER", 1983, 1997),
            ("CHL", 1983, 1990), ("COL", 1982, 1990), ("VEN", 1983, 1990),
            ("NGA", 1982, 1992), ("ZAF", 1985, 1993), ("TUR", 1978, 1982),
            ("PHL", 1983, 1992), ("IDN", 1998, 2000),
        ],
    }

    records = [
        {
            'iso3': code,
            'crisis_type': kind,
            'start_year': first,
            'end_year': last,
            'country_name': code,
        }
        for kind, spells in catalogue.items()
        for code, first, last in spells
    ]
    return pd.DataFrame(records)


def construct_crisis_indicators(panel, episodes):
    """Add 0/1 crisis and crisis-onset columns to ``panel`` from ``episodes``.

    For every episode whose country is present in the panel, the
    ``<type>_crisis`` column is set to 1 for all panel years between the
    episode's start and end (inclusive) and ``<type>_crisis_onset`` is set to
    1 in the start year. ``any_crisis`` / ``any_crisis_onset`` are the
    row-wise maxima over the three types.

    Args:
        panel: DataFrame with at least ``iso3`` and ``year``; modified in place.
        episodes: DataFrame with ``iso3``, ``crisis_type``, ``start_year``
            and ``end_year``.

    Returns:
        The same ``panel`` object, with the indicator columns added.
    """
    crisis_types = ['banking', 'currency', 'sovereign']
    level_cols = [f'{kind}_crisis' for kind in crisis_types]
    onset_cols = [f'{kind}_crisis_onset' for kind in crisis_types]

    # Start every indicator at zero
    for col in level_cols + onset_cols:
        panel[col] = 0

    known_countries = set(panel['iso3'].unique())
    skipped = set()
    flagged_obs = 0

    for _, episode in episodes.iterrows():
        code = episode['iso3']
        if code not in known_countries:
            skipped.add(code)
            continue

        kind = episode['crisis_type']
        first_year = episode['start_year']
        last_year = episode['end_year']

        in_country = panel['iso3'] == code

        # Every panel year covered by the episode
        active = in_country & (panel['year'] >= first_year) & (panel['year'] <= last_year)
        panel.loc[active, f'{kind}_crisis'] = 1
        flagged_obs += active.sum()

        # The start year alone marks the onset
        begins = in_country & (panel['year'] == first_year)
        panel.loc[begins, f'{kind}_crisis_onset'] = 1

    # Composite over the three crisis types
    panel['any_crisis'] = panel[level_cols].max(axis=1)
    panel['any_crisis_onset'] = panel[onset_cols].max(axis=1)

    print(f"\n  Crisis indicators constructed:")
    print(f"    Episodes matched to panel obs: {flagged_obs}")
    if skipped:
        print(f"    Countries not in panel: {sorted(skipped)[:15]}...")
    for kind in crisis_types + ['any']:
        level = panel[f'{kind}_crisis']
        n_years = level.sum()
        n_onset = panel[f'{kind}_crisis_onset'].sum()
        n_countries = panel.loc[level == 1, 'iso3'].nunique()
        print(f"    {kind}: {int(n_onset)} onsets, {int(n_years)} crisis-years, {n_countries} countries")

    return panel


def construct_sudden_stops(panel):
    """Add current-account reversal and sudden-stop indicators.

    Columns added (on a copy sorted by ``iso3``, ``year``):
      - ``d_ca_gdp``: within-country change in ``ca_gdp`` from the previous row
      - ``ca_reversal``: 1 when ``d_ca_gdp`` <= -3, else 0
      - ``ca_reversal_5pp``: 1 when ``d_ca_gdp`` <= -5, else 0
      - ``sudden_stop``: 1 when ``ca_reversal`` is 1 and ``output_gap`` < 0;
        constant 0 when the panel has no ``output_gap`` column

    Rows where the change is undefined (e.g. each country's first row) get 0
    for the indicators, because comparisons against NaN are False.

    NOTE(review): the indicators flag a *fall* in CA/GDP while the printed
    labels read "≥3pp" / "≥5pp" — confirm the intended sign convention.

    Returns:
        The sorted DataFrame with the new columns.
    """
    panel = panel.sort_values(['iso3', 'year'])

    # Row-to-row change in CA/GDP inside each country
    ca_change = panel.groupby('iso3')['ca_gdp'].diff()
    panel['d_ca_gdp'] = ca_change

    # Baseline (3pp) and stricter (5pp) thresholds
    panel['ca_reversal'] = ca_change.le(-3).astype(int)
    panel['ca_reversal_5pp'] = ca_change.le(-5).astype(int)

    if 'output_gap' not in panel.columns:
        panel['sudden_stop'] = 0
    else:
        reversal_hit = panel['ca_reversal'].eq(1)
        below_potential = panel['output_gap'].lt(0)
        panel['sudden_stop'] = (reversal_hit & below_potential).astype(int)

    print(f"\n  Sudden stop / CA reversal indicators:")
    print(f"    CA reversals (≥3pp): {panel['ca_reversal'].sum()}")
    print(f"    CA reversals (≥5pp): {panel['ca_reversal_5pp'].sum()}")
    print(f"    Sudden stops: {panel['sudden_stop'].sum()}")

    return panel


def construct_early_warning_vars(panel):
    """Add early-warning / control variables to a sorted copy of ``panel``.

    Columns added:
      - ``ca_gdp_lag1``: previous-row ``ca_gdp`` within country
      - ``d_gross_liab``: within-country change in ``gross_liab_gdp``
        (NaN when that column is absent)
      - ``reserves_to_liab``: ``fx_reserves_gdp`` / ``gross_liab_gdp`` with
        zero liabilities treated as missing (NaN when either input is absent)
      - ``youth_bulge``: 1 when ``youth_dep`` exceeds the panel-wide median,
        else 0 — rows with missing ``youth_dep`` therefore get 0
        (NaN column when ``youth_dep`` is absent)
      - ``nfa_gdp_lag``: created from ``nfa_gdp`` only if not already present
      - ``demo_tercile``: 'early' / 'mid' / 'late' tercile of ``old_dep``,
        NaN where ``old_dep`` is missing (only when ``old_dep`` exists)

    Returns:
        The DataFrame sorted by ``iso3``, ``year`` with the new columns.
    """
    panel = panel.sort_values(['iso3', 'year'])

    def previous_value(col):
        # Value from the preceding row of the same country
        return panel.groupby('iso3')[col].shift(1)

    def change(col):
        # Difference from the preceding row of the same country
        return panel.groupby('iso3')[col].diff()

    panel['ca_gdp_lag1'] = previous_value('ca_gdp')

    # Credit boom proxy
    has_liabilities = 'gross_liab_gdp' in panel.columns
    panel['d_gross_liab'] = change('gross_liab_gdp') if has_liabilities else np.nan

    # Reserve cover of gross liabilities
    if has_liabilities and 'fx_reserves_gdp' in panel.columns:
        liabilities = panel['gross_liab_gdp'].replace(0, np.nan)
        panel['reserves_to_liab'] = panel['fx_reserves_gdp'].div(liabilities)
    else:
        panel['reserves_to_liab'] = np.nan

    # Above-median youth dependency
    if 'youth_dep' not in panel.columns:
        panel['youth_bulge'] = np.nan
    else:
        cutoff = panel['youth_dep'].median()
        panel['youth_bulge'] = panel['youth_dep'].gt(cutoff).astype(int)
        print(f"  Youth bulge cutoff (panel median): {cutoff:.2f}")

    # Reuse an existing NFA lag if the panel ships one
    if 'nfa_gdp' in panel.columns and 'nfa_gdp_lag' not in panel.columns:
        panel['nfa_gdp_lag'] = previous_value('nfa_gdp')

    # Old-age dependency terciles for subgroup analysis
    if 'old_dep' in panel.columns:
        old_dep = panel['old_dep']
        terciles = pd.qcut(old_dep.dropna(), 3, labels=['early', 'mid', 'late'])
        panel['demo_tercile'] = terciles.astype(str)
        panel.loc[old_dep.isna(), 'demo_tercile'] = np.nan

    return panel


def write_summary_statistics(panel):
    """Write a Markdown table of summary statistics for the crisis variables.

    For each listed variable that exists in ``panel``, the non-missing count,
    mean, standard deviation, minimum and maximum are reported (four decimal
    places). The table is saved as ``summary_statistics.md`` in ``OUT_TABLES``.

    Args:
        panel: DataFrame holding the constructed variables plus ``iso3`` and
            ``year`` (used for the footer line).
    """
    crisis_vars = ['banking_crisis', 'currency_crisis', 'sovereign_crisis', 'any_crisis',
                   'banking_crisis_onset', 'any_crisis_onset',
                   'ca_reversal', 'ca_reversal_5pp', 'sudden_stop',
                   'd_ca_gdp', 'ca_gdp_lag1', 'reserves_to_liab', 'youth_bulge']

    lines = [
        "# Summary Statistics: Crisis & Sudden Stop Variables\n",
        "| Variable | N | Mean | Std | Min | Max |",
        "|:---|---:|---:|---:|---:|---:|",
    ]
    for var in crisis_vars:
        if var not in panel.columns:
            continue  # variables not built for this panel are simply omitted
        s = panel[var].dropna()
        lines.append(
            f"| {var} | {len(s)} | {s.mean():.4f} | {s.std():.4f} "
            f"| {s.min():.4f} | {s.max():.4f} |"
        )

    lines.append(f"\n*Panel: {panel['iso3'].nunique()} countries, "
                 f"{panel['year'].min()}–{panel['year'].max()}*")

    path = OUT_TABLES / "summary_statistics.md"
    # The footer contains a non-ASCII en dash; write UTF-8 explicitly so the
    # output does not depend on (or fail under) the platform's default encoding.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Run Phase 1 end to end.

    Loads the input panel (years up to 2024), obtains crisis episodes from the
    L&V archive — or from the hardcoded fallback when the archive is missing,
    unparseable or empty — builds crisis, sudden-stop and early-warning
    variables, then writes the episode list, the summary table and the final
    panel to disk.
    """
    banner = "=" * 70
    print(banner)
    print("PHASE 1: DATA PREPARATION — CRISIS & SUDDEN STOP VARIABLES")
    print(banner)

    # Input panel
    fp_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    print(f"\n  Loading: {fp_path}")
    raw_panel = pd.read_csv(fp_path)
    panel = raw_panel[raw_panel['year'] <= 2024].copy()
    print(f"  Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries, "
          f"{panel['year'].min()}–{panel['year'].max()}")

    # Crisis episodes: L&V archive first, hardcoded list as a safety net
    print("\n--- Laeven & Valencia Crisis Database ---")
    zip_path = download_lv_database()

    episodes = None
    if zip_path is None or not zip_path.exists():
        print("  Using fallback crisis data (hardcoded major episodes)")
    else:
        try:
            episodes = parse_lv_excel(zip_path)
        except Exception as e:
            print(f"  Error parsing L&V: {e}")
            print("  Using fallback crisis data")
            episodes = None
        else:
            if len(episodes) == 0:
                print("  Parsing returned no episodes, using fallback")
                episodes = None
    if episodes is None:
        episodes = build_fallback_crisis_data()

    # Keep the episode list on disk for reference
    episodes_path = OUT_DATA / "lv_episodes.csv"
    episodes.to_csv(episodes_path, index=False)
    print(f"  Saved episodes: {episodes_path}")

    # Variable construction
    print("\n--- Crisis Indicators ---")
    panel = construct_crisis_indicators(panel, episodes)

    print("\n--- Sudden Stops & CA Reversals ---")
    panel = construct_sudden_stops(panel)

    print("\n--- Early Warning Variables ---")
    panel = construct_early_warning_vars(panel)

    write_summary_statistics(panel)

    # Final panel
    out_path = OUT_DATA / "crises_panel.csv"
    panel.to_csv(out_path, index=False)
    print(f"\n  Saved: {out_path}")
    print(f"  Shape: {panel.shape}")
    print(f"  Columns added: banking_crisis, currency_crisis, sovereign_crisis, "
          "any_crisis, ca_reversal, sudden_stop, etc.")

    print("\n" + banner)
    print("PHASE 1 COMPLETE")
    print(banner)


if __name__ == '__main__':
    main()
